# ------------------------------------------------------------------------------
# Fits LASSO stent model
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 02_fit_lassos.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Pin the RNG state up front so any random draw made later in this script is
# reproducible from run to run.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(glue)
library(Matrix)
library(doParallel)
library(glmnet)
library(optparse)
library(here)
library(testit)

# Load the project utility module from lib/util.R and bind it to `u`.
# NOTE(review): `u` is not referenced anywhere else in the visible script;
# confirm whether the load is still needed (e.g. for side effects).
u <- modules::use(here("lib", "util.R"))
# Working directory for this pipeline step: the cohort split is read from it
# and the fitted models are written back to it.
temp <- here("code", "08_train_split_model", "temp")

# Load Data --------------------------------------------------------------------
message("Loading data...")

# Resolve both input locations first, then read them.
paths <- read_yaml(here("lib", "filepaths.yml"))
features_fp <- paths$features$train_tested
cohort_fp <- file.path(temp, "split_train_cohort.rds")

# `x` holds the model features; `ids` holds the per-row cohort columns
# (exclusion flags, split indicators, fold ids, outcome) used further down.
x <- readRDS(features_fp)
ids <- readRDS(cohort_fp)

# Subset Data ------------------------------------------------------------------
message("Subsetting data...")

# `x` and `ids` are filtered with the same row positions below, so they must be
# row-aligned. A size mismatch would otherwise silently pair features with the
# wrong cohort rows.
if (nrow(x) != nrow(ids)) {
  stop(
    "Feature matrix has ", nrow(x), " rows but cohort table has ", nrow(ids),
    "; they must be row-aligned.",
    call. = FALSE
  )
}

# Keep only rows for which none of the three exclusion flags is set.
keep_pop <- !ids$excl_flag_c_int & !ids$excl_flag_chronic & !ids$excl_flag_death

# A missing flag column evaluates to NULL, which collapses the whole mask to
# length zero; fail loudly instead of continuing with an empty data set.
if (length(keep_pop) != nrow(ids)) {
  stop(
    "Exclusion mask has length ", length(keep_pop), " but cohort table has ",
    nrow(ids), " rows; check that all excl_flag_* columns exist.",
    call. = FALSE
  )
}

# which() keeps the positions of TRUE entries only, so rows whose combined
# flag is NA are dropped as well.
keep <- which(keep_pop)
if (length(keep) == 0L) {
  stop("No rows remain after applying the exclusion flags.", call. = FALSE)
}

# drop = FALSE keeps `x` two-dimensional even if a single row survives.
x <- x[keep, , drop = FALSE]
ids <- ids[keep, ]

# Tuning -----------------------------------------------------------------------
# For each of the two sample splits, fit a cross-validated LASSO (alpha = 1)
# logistic model for `stent_or_cabg_010_day`, scored by AUC, and save the
# resulting cv.glmnet object to the temp directory.
message("Fitting lassos...")
for (i in 1:2) {
  message("Split ", i)
  split_var <- glue("sample_split{i}")

  # A missing split column would come back as NULL and silently produce an
  # empty split; stop with a clear message instead.
  in_split <- ids[[split_var]]
  if (is.null(in_split)) {
    stop("Column '", split_var, "' not found in cohort table.", call. = FALSE)
  }
  keep_split <- which(in_split)

  # drop = FALSE keeps the feature matrix two-dimensional for cv.glmnet.
  x_split <- x[keep_split, , drop = FALSE]
  ids_split <- ids[keep_split, ]

  # One worker per cross-validation fold actually present in this split.
  registerDoParallel(cores = uniqueN(ids_split$train_fold))

  # Folds are taken from the precomputed `train_fold` column rather than drawn
  # at random. `finally` shuts the implicit cluster down even when the fit
  # errors, so no workers are left behind.
  tuning_result <- tryCatch(
    cv.glmnet(
      x = x_split,
      y = ids_split$stent_or_cabg_010_day,
      foldid = ids_split$train_fold,
      parallel = TRUE,
      family = "binomial",
      type.measure = "auc",
      alpha = 1
    ),
    finally = stopImplicitCluster()
  )

  message("Saving model ", i, "...")
  save_fp <- file.path(temp, glue("lasso__stent_or_cabg_010_day__tested__{i}.rds"))
  saveRDS(tuning_result, save_fp)
}

# Done -------------------------------------------------------------------------
message("Done.")
